/*
 * @Date: 2022-08-01 22:59:59
 * @LastEditors: wangjun haodreams@163.com
 * @LastEditTime: 2024-12-25 16:38:24
 * @FilePath: \golib\nat\nat.go
 * @Description: bidirectional byte relay between a source and a destination net.Conn
 */
package nat

import (
	"errors"
	"net"
	"sync"
	"time"

	"gitee.com/haodreams/libs/easy"
	"gitee.com/haodreams/libs/routine"
)

// Option configures a Nat. Options are applied by NewNat / NewNatFromSource
// before the onCreate callback is invoked.
type Option func(*Nat)

// WithBufSize sets the size in bytes of each per-direction copy buffer
// (default 4096). The value is clamped to the range [512, 65535].
func WithBufSize(size int) Option {
	switch {
	case size < 512:
		size = 512
	case size > 65535:
		size = 65535
	}
	return func(n *Nat) { n.bufSize = size }
}

// WithID sets the numeric identifier that is included in log lines.
func WithID(id uint32) Option {
	return func(n *Nat) { n.ID = id }
}

// WithDesc sets a free-form description of the connection, included in log lines.
func WithDesc(desc string) Option {
	return func(n *Nat) { n.Desc = desc }
}

// WithParent records parent-node information in the Parent field.
func WithParent(parent string) Option {
	return func(n *Nat) { n.Parent = parent }
}

// WithLog installs a logger; Nat() uses it to report when the tunnel is
// opened and when it is closed.
func WithLog(log func(v ...any)) Option {
	return func(n *Nat) { n.log = log }
}

// WithOnLostConnect installs a callback that Close invokes after both
// connections have been closed.
func WithOnLostConnect(onLostConnect func(*Nat)) Option {
	return func(n *Nat) { n.onLostConnect = onLostConnect }
}

// WithOnCreate installs a callback that the constructors invoke once all
// options have been applied.
func WithOnCreate(onCreate func(*Nat)) Option {
	return func(n *Nat) { n.onCreate = onCreate }
}

// Nat relays bytes in both directions between a source and a destination
// connection and keeps simple statistics about the transfer.
type Nat struct {
	bufSize       int            // size of each per-direction copy buffer
	dest          net.Conn       // destination side; nil until WaitDestConn attaches it (NewNatFromSource)
	src           net.Conn       // source side
	chanConn      chan net.Conn  // hand-off from PutDestConn to WaitDestConn; only set by NewNatFromSource
	log           func(v ...any) // optional logger (WithLog)
	onLostConnect func(*Nat)     // optional callback run by Close (WithOnLostConnect)
	onCreate      func(*Nat)     // optional callback run by the constructors (WithOnCreate)
	ID            uint32         // numeric identifier, used in log lines
	Parent        string         // parent node
	Desc          string         // description
	CreateTime    int64          // Unix seconds at which the relay became ready
	LastTime      int64          // Unix seconds of the most recent data transfer
	DestHost      string         // remote address of dest ("-" until known)
	SrcHost       string         // remote address of src
	RecvSize      uint64         // bytes read from dest (forwarded to src)
	SendSize      uint64         // bytes read from src (forwarded to dest)
	Closed        bool           // set once Close has run
	Msg           string         // human-readable status
}

// NewNat wires an already-established pair of connections together and
// immediately starts relaying between them on a new goroutine (m.Nat).
// Options are applied first, then onCreate is invoked, and CreateTime is
// stamped afterwards.
func NewNat(dest, src net.Conn, opts ...Option) (m *Nat) {
	m = &Nat{bufSize: 4096, dest: dest, src: src}

	// Record the peers and clear any read deadline previously set on either
	// connection so the relay loops can block indefinitely.
	m.DestHost = dest.RemoteAddr().String()
	dest.SetReadDeadline(time.Time{})
	src.SetReadDeadline(time.Time{})
	m.SrcHost = src.RemoteAddr().String()

	for _, apply := range opts {
		apply(m)
	}
	if cb := m.onCreate; cb != nil {
		cb(m)
	}
	m.CreateTime = time.Now().Unix()

	go m.Nat()
	return m
}

// NewNatFromSource prepares a relay for which only the source side is known.
// No goroutine is started here. The destination is supplied later through
// PutDestConn and collected by WaitDestConn; presumably the caller then runs
// Nat (TODO confirm against callers). DestHost stays "-" until then.
func NewNatFromSource(src net.Conn, opts ...Option) (m *Nat) {
	m = &Nat{bufSize: 4096, DestHost: "-"}

	// Record the peer and clear any read deadline previously set on src.
	m.SrcHost = src.RemoteAddr().String()
	src.SetReadDeadline(time.Time{})
	m.src = src

	for _, apply := range opts {
		apply(m)
	}

	m.Msg = "连接准备中"
	// Single-slot hand-off used by PutDestConn / WaitDestConn.
	m.chanConn = make(chan net.Conn, 1)

	if cb := m.onCreate; cb != nil {
		cb(m)
	}
	return m
}

// PutDestConn hands a freshly established destination connection to a Nat
// created by NewNatFromSource; WaitDestConn picks it up. On success,
// ownership of conn passes to the Nat. On error the Nat never touches conn,
// so the caller still owns it and must close it.
//
// The call never blocks. Only one connection can be queued, and
// WaitDestConn receives a single connection. A second connection is
// therefore rejected instead of blocking the caller forever.
func (m *Nat) PutDestConn(conn net.Conn) (err error) {
	if conn == nil {
		// A nil conn would make WaitDestConn panic on RemoteAddr().
		return errors.New("目标连接为空")
	}
	if m.Closed {
		return errors.New("连接已关闭")
	}
	// Also covers a Nat built by NewNat, whose chanConn is nil (cap 0).
	if cap(m.chanConn) == 0 {
		return errors.New("资源已释放")
	}
	select {
	case m.chanConn <- conn:
		return nil
	default:
		return errors.New("目标连接已存在")
	}
}

// WaitDestConn blocks until PutDestConn supplies the destination connection
// or until msTimeout milliseconds elapse. It only works on a Nat created by
// NewNatFromSource; otherwise chanConn is nil and an error is returned.
//
// On success the connection becomes the destination side: DestHost is
// filled in, its read deadline is cleared and CreateTime is stamped. On
// timeout the Nat is closed and an error is returned.
//
// PutDestConn transfers ownership of the connection. A connection that
// arrives after the Nat was closed is therefore closed here so it does not
// leak.
func (m *Nat) WaitDestConn(msTimeout int64) (err error) {
	if m.chanConn == nil {
		return errors.New("必须先调用函数:NewNatFromSource()")
	}
	m.Msg = "等待连接响应"
	select {
	case dest := <-m.chanConn:
		if m.Closed {
			// Nobody else holds this conn any more; drop it instead of leaking it.
			if dest != nil {
				dest.Close()
			}
			return errors.New("连接已关闭")
		}
		m.DestHost = dest.RemoteAddr().String()
		dest.SetReadDeadline(time.Time{})
		m.dest = dest
		m.CreateTime = time.Now().Unix()
	case <-time.After(time.Millisecond * time.Duration(msTimeout)):
		m.Msg = "连接响应超时"
		err = errors.New(m.Msg)
		m.Close()
		// Best effort: a PutDestConn racing with the timeout may have queued a
		// connection that will now never be consumed, so close it.
		select {
		case late := <-m.chanConn:
			if late != nil {
				late.Close()
			}
		default:
		}
	}
	return err
}

// natCloseMu guards the check-and-set of Nat.Closed inside Close. Close is
// called concurrently by both copy loops, the idle watchdog and
// WaitDestConn. Without this lock, two callers could both observe
// Closed == false, and onLostConnect would run more than once.
var natCloseMu sync.Mutex

// Close tears the relay down. It marks the Nat closed, closes both
// connections (dest only if one has been attached), sets the status
// message and invokes the onLostConnect callback. Only the first call has
// any effect; later or concurrent calls return immediately.
func (m *Nat) Close() {
	natCloseMu.Lock()
	if m.Closed {
		natCloseMu.Unlock()
		return
	}
	m.Closed = true
	// Release before doing I/O and running the user callback, so a callback
	// that closes another Nat cannot deadlock on this lock.
	natCloseMu.Unlock()

	if m.dest != nil {
		m.dest.Close()
	}
	m.src.Close()
	m.Msg = "通讯已关闭"
	if m.onLostConnect != nil {
		m.onLostConnect(m)
	}
}

// Nat runs the relay on the calling goroutine and returns once it has been
// torn down. If a logger was configured with WithLog, it logs when the
// tunnel opens. On return it logs the close together with the byte
// counters: "S:" is SendSize (src→dest) and "R:" is RecvSize (dest→src).
func (m *Nat) Nat() {
	if m.log == nil {
		m.nat()
		return
	}

	m.log(m.ID, m.Desc, m.SrcHost, "<====>", m.DestHost, "通道已建立")
	defer func() {
		sent := easy.BeautifySize(int64(m.SendSize))
		recv := easy.BeautifySize(int64(m.RecvSize))
		m.log(m.ID, m.Desc, m.SrcHost, "<====>", m.DestHost, "通道已关闭",
			"S:", sent,
			"R:", recv,
		)
	}()
	m.nat()
}
// nat starts the idle watchdog and the src→dest copier on background
// goroutines, then runs the dest→src copier on the calling goroutine.
// Whichever part stops first calls Close. Closing the connections in turn
// makes the remaining reads fail, so everything winds down together.
func (m *Nat) nat() {
	if m.dest == nil {
		// The Nat was built by NewNatFromSource but no destination was ever
		// attached (e.g. WaitDestConn timed out). Reading from a nil conn would
		// panic inside a goroutine and crash the process, so close instead.
		m.Close()
		return
	}
	m.LastTime = time.Now().Unix()
	m.Msg = "通讯中"

	go m.watchIdle()
	go m.pipe(m.src, m.dest, &m.SendSize)
	m.pipe(m.dest, m.src, &m.RecvSize)
}

// natIdleTimeoutSec is how long (in seconds) the relay may stay silent in
// both directions before it is closed: 15 minutes.
const natIdleTimeoutSec = 900

// watchIdle polls LastTime and closes the Nat once no data has moved for
// natIdleTimeoutSec. It also closes the Nat when the Nat is already closed
// or routine reports that it is no longer running.
func (m *Nat) watchIdle() {
	for !m.Closed && routine.IsRunning() {
		if m.LastTime < time.Now().Unix()-natIdleTimeoutSec {
			break
		}
		routine.Sleep(2000) // NOTE(review): argument assumed to be milliseconds.
	}
	m.Close()
}

// pipe copies from → to until a read or write fails or routine stops
// running, then closes the Nat. Every byte read is added to *counter, and
// each non-empty read refreshes LastTime for the idle watchdog.
func (m *Nat) pipe(from, to net.Conn, counter *uint64) {
	buf := make([]byte, m.bufSize)
	for routine.IsRunning() {
		n, err := from.Read(buf)
		if err != nil {
			break
		}
		*counter += uint64(n)
		if n == 0 {
			continue
		}
		m.LastTime = time.Now().Unix()
		if _, err = to.Write(buf[:n]); err != nil {
			break
		}
	}
	m.Close()
}
